{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bs2_45zZdqwt"
      },
      "outputs": [],
      "source": [
        "from transformers import BertTokenizerFast, BertForSequenceClassification\n",
        "\n",
        "# Name of the base checkpoint; the tokenizer is loaded from this same name.\n",
        "model_path = \"bert-base-uncased\"\n",
        "tokenizer = BertTokenizerFast.from_pretrained(model_path)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "from datasets import load_dataset\n",
        "\n",
        "# Fixed slices of IMDB taken from both ends of each split\n",
        "# (presumably so both labels are represented - verify against the dataset ordering).\n",
        "imdb_train = load_dataset(\"imdb\", split=\"train[:2000]+train[-2000:]\")\n",
        "imdb_test = load_dataset(\"imdb\", split=\"test[:500]+test[-500:]\")\n",
        "imdb_val = load_dataset(\"imdb\", split=\"test[500:1000]+test[-1000:-500]\")\n",
        "\n",
        "# Recorded result of the expression below: ((4000, 2), (1000, 2), (1000, 2))\n",
        "# (this used to be a bare `OUTPUT: ...` line, which is not valid Python).\n",
        "imdb_train.shape, imdb_test.shape, imdb_val.shape\n",
        "\n",
        "\n",
        "def tokenize_it(e):\n",
        "    \"\"\"Tokenize the \"text\" field of a batch with padding and truncation enabled.\n",
        "\n",
        "    Args:\n",
        "        e: Batch mapping that provides the raw reviews under the key \"text\".\n",
        "\n",
        "    Returns:\n",
        "        The tokenizer output for the batch.\n",
        "    \"\"\"\n",
        "    return tokenizer(e[\"text\"], padding=True, truncation=True)\n",
        "\n",
        "\n",
        "# Tokenize each slice in batches of 1000 examples.\n",
        "enc_train = imdb_train.map(tokenize_it, batched=True, batch_size=1000)\n",
        "enc_test = imdb_test.map(tokenize_it, batched=True, batch_size=1000)\n",
        "enc_val = imdb_val.map(tokenize_it, batched=True, batch_size=1000)"
      ],
      "metadata": {
        "id": "HjSJ44Cdd4cM"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Full IMDB training split; only its raw review text is kept for the adaptation step.\n",
        "dataset_for_adaptation = load_dataset(\"imdb\", split=\"train\")\n",
        "imdb_sentences = dataset_for_adaptation[\"text\"]\n",
        "\n",
        "# The first 20,000 texts are the training portion, everything after is the dev portion.\n",
        "train_sentences = imdb_sentences[:20000]\n",
        "dev_sentences = imdb_sentences[20000:]"
      ],
      "metadata": {
        "id": "mZamQaBkeaQJ"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from transformers import AutoModelForMaskedLM, AutoTokenizer\n",
        "\n",
        "# Load the checkpoint named by model_path through the masked-LM auto class.\n",
        "model = AutoModelForMaskedLM.from_pretrained(model_path)"
      ],
      "metadata": {
        "id": "QT3WOeBqef52"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Settings for the masked-language-modelling run.\n",
        "batch_size = 16        # used as per_device_train_batch_size\n",
        "num_train_epochs = 15  # forwarded to TrainingArguments\n",
        "max_length = 100       # truncation length handed to TokenizedSentencesDataset\n",
        "mlm_prob = 0.25        # mlm_probability of the data collator"
      ],
      "metadata": {
        "id": "AlfK1Hk0enAg"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "class TokenizedSentencesDataset:\n",
        "    \"\"\"Map-style dataset that tokenizes raw sentences when they are accessed.\n",
        "\n",
        "    Args:\n",
        "        sentences: Indexable, sized collection of sentences.\n",
        "        tokenizer: Callable invoked with a sentence plus the keyword arguments\n",
        "            add_special_tokens, truncation, max_length and return_special_tokens_mask.\n",
        "        max_length: Truncation length forwarded to the tokenizer.\n",
        "        cache_tokenization: When true, each string sentence is tokenized once and\n",
        "            the stored encoding is returned on later accesses.\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, sentences, tokenizer, max_length, cache_tokenization=False):\n",
        "        self.tokenizer = tokenizer\n",
        "        self.sentences = sentences\n",
        "        self.max_length = max_length\n",
        "        self.cache_tokenization = cache_tokenization\n",
        "        # Encodings are kept in a private dict so the caller's `sentences` object\n",
        "        # is never overwritten (it may be shared or immutable).\n",
        "        self._cache = {}\n",
        "\n",
        "    def _tokenize(self, sentence):\n",
        "        \"\"\"Run the tokenizer with the settings shared by every access path.\"\"\"\n",
        "        return self.tokenizer(\n",
        "            sentence,\n",
        "            add_special_tokens=True,\n",
        "            truncation=True,\n",
        "            max_length=self.max_length,\n",
        "            return_special_tokens_mask=True,\n",
        "        )\n",
        "\n",
        "    def __getitem__(self, item):\n",
        "        \"\"\"Return the tokenizer output for sentences[item].\"\"\"\n",
        "        sentence = self.sentences[item]\n",
        "        if not self.cache_tokenization:\n",
        "            return self._tokenize(sentence)\n",
        "        # Entries that are not plain strings are handed back untouched; this check\n",
        "        # runs before the cache lookup so unhashable indices never reach the dict.\n",
        "        if not isinstance(sentence, str):\n",
        "            return sentence\n",
        "        if item not in self._cache:\n",
        "            self._cache[item] = self._tokenize(sentence)\n",
        "        return self._cache[item]\n",
        "\n",
        "    def __len__(self):\n",
        "        \"\"\"Return the number of sentences.\"\"\"\n",
        "        return len(self.sentences)"
      ],
      "metadata": {
        "id": "YDFKAt4PepQt"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Wrap the raw text portions; each item is tokenized with max_length on access.\n",
        "train_dataset = TokenizedSentencesDataset(train_sentences, tokenizer, max_length)\n",
        "dev_dataset = TokenizedSentencesDataset(dev_sentences, tokenizer, max_length)"
      ],
      "metadata": {
        "id": "sCSG8aGyezFB"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# These three names are not imported anywhere else in the notebook.\n",
        "from transformers import DataCollatorForLanguageModeling, Trainer, TrainingArguments\n",
        "\n",
        "# Collator configured for masked-language modelling with masking probability mlm_prob.\n",
        "data_collator = DataCollatorForLanguageModeling(\n",
        "    tokenizer=tokenizer, mlm=True, mlm_probability=mlm_prob\n",
        ")\n",
        "\n",
        "training_args = TrainingArguments(\n",
        "    # Explicit checkpoint directory; older transformers releases require this argument.\n",
        "    output_dir=\"mlm-adaptation-checkpoints\",\n",
        "    num_train_epochs=num_train_epochs,\n",
        "    # NOTE(review): newer transformers releases call this option `eval_strategy` -\n",
        "    # confirm against the installed version.\n",
        "    evaluation_strategy=\"steps\",\n",
        "    per_device_train_batch_size=batch_size,\n",
        "    prediction_loss_only=True,\n",
        "    # NOTE(review): fp16 presumably needs a GPU runtime - confirm before running on CPU.\n",
        "    fp16=True,\n",
        ")\n",
        "\n",
        "trainer = Trainer(\n",
        "    model=model,\n",
        "    args=training_args,\n",
        "    data_collator=data_collator,\n",
        "    train_dataset=train_dataset,\n",
        "    eval_dataset=dev_dataset,\n",
        ")"
      ],
      "metadata": {
        "id": "0v0dJzYie3Hu"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Start training; the value returned by train() is shown as the cell output.\n",
        "trainer.train()"
      ],
      "metadata": {
        "id": "VS4K_ytGe_iR"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Save the model and the tokenizer into the same directory so both can be\n",
        "# reloaded later from that single path.\n",
        "adapted_model_path = \"adapted-bert\"\n",
        "model.save_pretrained(adapted_model_path)\n",
        "tokenizer.save_pretrained(adapted_model_path)"
      ],
      "metadata": {
        "id": "so7XaUPJfDSj"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Checkpoint to load as a sequence classifier. Two options:\n",
        "#   1) \"adapted-bert\"      - the adapted model (active)\n",
        "#   2) \"bert-base-uncased\" - vanilla BERT\n",
        "model_path = \"adapted-bert\"\n",
        "# model_path = \"bert-base-uncased\"\n",
        "\n",
        "# Two labels: 0 <-> \"NEG\", 1 <-> \"POS\".\n",
        "model = BertForSequenceClassification.from_pretrained(\n",
        "    model_path,\n",
        "    id2label={0: \"NEG\", 1: \"POS\"},\n",
        "    label2id={\"NEG\": 0, \"POS\": 1},\n",
        ")"
      ],
      "metadata": {
        "id": "raNckyE7fG6p"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [],
      "metadata": {
        "id": "oliqQVdufO38"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}